{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "code_folding": [],
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "购买苹果后同时黄油的置信度为：0.694\n",
      "购买黄油后同时香蕉蕉的置信度为：0.659\n",
      "购买面包后同时香蕉蕉的置信度为：0.630\n",
      "购买黄油后同时苹果的置信度为：0.610\n",
      "购买苹果后同时香蕉蕉的置信度为：0.583\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "# Per-rule tallies keyed by (premise_index, conclusion_index); connect() increments them.\n",
    "valid_rules = defaultdict(int)\n",
    "invalid_rules = defaultdict(int)\n",
    "# NOTE(review): not read or written anywhere else in this cell; kept in case other code relies on the name.\n",
    "num_occurance = defaultdict(int)\n",
    "\n",
    "data_filename = \"affinity_dataset.txt\"\n",
    "# ndmin=2 keeps the result two-dimensional even when the file holds a single row,\n",
    "# so iterating over it always yields one row per sample.\n",
    "datas = np.loadtxt(data_filename, ndmin=2)\n",
    "# Display labels; the index used for a data column in connect() is the same index used to pick a label here.\n",
    "features = [\"面包\", \"牛奶\", \"黄油\", \"苹果\", \"香蕉\"]\n",
    "# Relationship between apples and bananas\n",
    "\n",
    "# Relationship between product A and product B\n",
    "def connect(indexA, indexB):\n",
    "    \"\"\"Tally, over the rows of ``datas``, how a purchase of A relates to B.\n",
    "\n",
    "    Rows whose ``indexA`` column equals 0 are ignored. Every other row adds\n",
    "    one to ``valid_rules[(indexA, indexB)]`` when its ``indexB`` column equals\n",
    "    1, and one to ``invalid_rules[(indexA, indexB)]`` otherwise.\n",
    "\n",
    "    Returns the number of rows that were not ignored.\n",
    "    \"\"\"\n",
    "    rule = (indexA, indexB)\n",
    "    premise_total = 0\n",
    "    for row in datas:\n",
    "        if row[indexA] != 0:\n",
    "            premise_total += 1\n",
    "            # Pick the tally this row contributes to, then bump it.\n",
    "            tally = valid_rules if row[indexB] == 1 else invalid_rules\n",
    "            tally[rule] += 1\n",
    "    return premise_total\n",
    "\n",
    "\n",
    "def get_confidence():\n",
    "    \"\"\"Compute the confidence of every rule recorded in ``valid_rules``.\n",
    "\n",
    "    The confidence of a rule is ``valid / (valid + invalid)``, using the\n",
    "    tallies stored in the module-level ``valid_rules`` and ``invalid_rules``.\n",
    "\n",
    "    Returns:\n",
    "        A ``defaultdict(float)`` mapping each rule key to its confidence.\n",
    "        Rules whose two tallies sum to zero are left out.\n",
    "    \"\"\"\n",
    "    confidence = defaultdict(float)\n",
    "    for rule, valid_count in valid_rules.items():\n",
    "        # .get() avoids inserting a zero entry into invalid_rules, which a\n",
    "        # plain defaultdict lookup would do as a side effect.\n",
    "        total = valid_count + invalid_rules.get(rule, 0)\n",
    "        if total == 0:\n",
    "            # A key with no observations would otherwise divide by zero.\n",
    "            continue\n",
    "        confidence[rule] = valid_count / total\n",
    "    return confidence\n",
    "if __name__ == \"__main__\":\n",
    "    # Tally every ordered pair of distinct products.\n",
    "    for i in range(len(features)):\n",
    "        for j in range(len(features)):\n",
    "            if i != j:\n",
    "                connect(i, j)\n",
    "    confidence = get_confidence()\n",
    "\n",
    "    # Rank the rules from highest to lowest confidence.\n",
    "    sort_dict = sorted(confidence.items(), key=lambda item: item[1], reverse=True)\n",
    "    # Slicing (rather than indexing positions 0..4) keeps this from raising\n",
    "    # IndexError when fewer than five rules exist.\n",
    "    for rule, rule_confidence in sort_dict[:5]:\n",
    "        print(\"购买{0}后同时{1}的置信度为：{2:0.3f}\".format(features[rule[0]], features[rule[1]], rule_confidence))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
